Introduction to SHAP
Centro Singular de Investigación en Tecnoloxías Intelixentes (CiTIUS)
Universidade de Santiago de Compostela
ETSE-USC, Campus Vida, Santiago de Compostela, Spain
Mr. Pablo Miguel Perez-Ferreiro
Centro Singular de Investigación en Tecnoloxías Intelixentes (CiTIUS)
Universidade de Santiago de Compostela
ETSE-USC, Campus Vida, Santiago de Compostela, Spain
27 January 2026
Explainable and Trustworthy AI

1. Introduction
This interactive tutorial includes supplementary material for the first TXAI in the Lab hands-on session (I1. "Introduction to SHAP") of the subject Explainable and Trustworthy AI (Master in Artificial Intelligence). The session is led by Jose M. Alonso-Moral and Pablo Miguel Perez-Ferreiro at USC, Samuel Suárez Marcote at UDC, and David Nicholas Olivieri Cecchi at UVigo.
2. Settings
In this section, we prepare the software needed to run the notebook. Please abstain from changing anything in this section unless you're prompted to by your teachers, as it may break the notebook's functionality.
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
import os
os.system("pip install simplenlg --quiet")
os.system("pip install tabulate --quiet")
os.system("pip install numba==0.59.1 --quiet")
os.system("pip install salib==1.3.3 --quiet")
os.system("pip install numpy==1.26.4 --quiet")
os.system("pip install shap==0.46.0 --quiet")
os.system("pip install interpret==0.5.0 --quiet")
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
import warnings
warnings.filterwarnings('ignore')
# Loading matplotlib for plotting
import matplotlib.pyplot as plt
# Loading pandas for their DataFrames and some management functions
import pandas as pd
# Loading seaborn for plotting
import seaborn as sns
# Loading numpy for utilities
import numpy as np
# Loading lib to deal with arff files
from scipy.io.arff import loadarff
# Loading sklearn and several of its modules: they will allow us to build and validate the models whose interpretability we will study
import sklearn
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import cross_validate
from sklearn.metrics import classification_report
from sklearn.tree import export_text
# Loading InterpretML, a library specialized in offering explainable by design models.
from interpret import set_visualize_provider
from interpret.provider import InlineProvider
set_visualize_provider(InlineProvider())
from interpret.glassbox import ExplainableBoostingClassifier
from interpret import show
# Loading a library to handle SHAP Values
import shap
# Loading Graphviz in order to plot Decision Trees
import graphviz
## this code is developed by Jose Maria Alonso-Moral
## Auxiliary functions for the notebook:
# Plots a Pareto front between two variables (x, y), labeled (labelx, labely), with the x-axis limits defined by (minx, maxx)
def plot_pareto_front(x, y, n, labelx, labely, minx, maxx):
    plt.title("Pareto Front")
    plt.xlabel(labelx)
    plt.ylabel(labely)
    plt.axis([minx, maxx, 0, 1])
    c = ["ro", "bo", "go", "rs", "bs", "gs", "r*", "b*", "g*", "r+", "b+", "g+"]
    for m_idx, m in enumerate(n):  # enumerate avoids n.index(m), which breaks on duplicate names
        plt.plot(x[m_idx], y[m_idx], c[m_idx], label=m)
    plt.grid(True)
    plt.legend()
    plt.show()
# Extracts the length of SHAP-based explanations; used in this practical as a surrogate interpretability metric
def get_shap_explanation_length(single_lower_triangular_interactions, indexes=None, th=0.9):
    # Cumulative sum of the absolute SHAP contributions
    shap_cumsum = np.cumsum(np.abs(single_lower_triangular_interactions))
    # Normalize the cumulative sum so that it ends at 1.0
    normalised_shap_cumsum = shap_cumsum / shap_cumsum[-1]
    # Find the index of the first element that exceeds the threshold
    first_above_idx = 0
    for i, val in enumerate(normalised_shap_cumsum):
        if val > th:
            first_above_idx = i
            break
    return first_above_idx
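As a quick sanity check, the helper above can be exercised on a toy attribution vector. This is a hedged, self-contained sketch: it re-declares the same logic (via `np.argmax`, which finds the first `True`) so it runs standalone, and the vector values are made up for illustration.

```python
import numpy as np

# Same logic as get_shap_explanation_length above, re-declared so this
# snippet runs on its own.
def shap_explanation_length(contributions, th=0.9):
    cumsum = np.cumsum(np.abs(contributions))
    normalised = cumsum / cumsum[-1]
    # index of the first element whose normalized cumulative mass exceeds th
    return int(np.argmax(normalised > th))

# Toy attributions: normalized cumulative mass is [0.5, 0.8, 0.9, 0.95, 1.0],
# so the first entry strictly above 0.9 sits at index 3.
vals = np.array([5.0, 3.0, 1.0, 0.5, 0.5])
print(shap_explanation_length(vals))  # -> 3
```

Note that with the default threshold of 0.9, a shorter returned index means fewer features are needed to cover 90% of the attribution mass, i.e. a more compact explanation.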
3. Interpretable systems: from the ground up
In this section, we will go through the process of building an interpretable system step-by-step, working with both interpretable-by-design and opaque models. In this sense, we will:
- Import and pre-process a suitable dataset.
- Build and validate several ML models for it.
- Explore the properties of the built models, and perform a first attempt at interpreting those which are transparent through direct inspection.
- Utilize SHAP Values as a post-hoc explainability tool that allows us to overcome some of the problems brought by direct inspection.
- Perform a preliminary evaluation of the trade-offs between performance and explainability that the field of TXAI usually involves.
3.1. Loading data and preliminary exploration
We will be working with the Pima Indians Diabetes Dataset in the usual Weka arff format.
A classic benchmark dataset for machine learning, it contains 768 instances, each describing a female patient of Pima Indian heritage at least 21 years old. It is intended for binary classification: predicting the onset of diabetes from 8 features defined by the World Health Organization:
- Number of times pregnant
- Plasma glucose concentration
- Diastolic blood pressure
- Triceps skin fold thickness
- Two hour serum insulin
- Body mass index
- Diabetes pedigree function
- Age
We will begin by importing it from your workspace, and then we will pre-process the data and explore it.
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
# Importing the train and test sets from the arff files, and setting the class and attribute names.
file_train = 'testlib/PIMA/PIMA.train.0.arff'
file_test = 'testlib/PIMA/PIMA.test.0.arff'
# The with statement closes each file automatically, so no explicit close is needed
with open(file_train) as f:
    train_data_value, attributes = loadarff(f)
with open(file_test) as f:
    test_data_value, attributes = loadarff(f)
pima_class_names= ['tested_negative','tested_positive']
pima_fnames = ["Number of times pregnant", "Plasma glucose concentration", "Diastolic blood_pressure", "Triceps skin fold thickness", "Two hour serum insulin", "Body mass index", "Diabetes pedigree function", "Age"]
# We prepare the train sets
class_names = np.array(pima_class_names)
feature_names = np.array(attributes.names())
df=pd.DataFrame(train_data_value)
df.columns = feature_names
target = df.pop('class')
target_onehot = pd.get_dummies(target)[b'2.0']
x_tr = df
y_tr = target_onehot
y_tr_num = [float(out_class) for out_class in y_tr] # we will need the numerical outputs to simulate regression
# And the test sets too
df_test=pd.DataFrame(test_data_value)
x_test = df_test
target_test = df_test.pop('class')
target_test_onehot = pd.get_dummies(target_test)[b'2.0']
y_test = target_test_onehot
y_test_num = [float(out_class) for out_class in y_test] # likewise
# Print some general information of the dataset
print(f'Class names for the PIMA Dataset:\n\t{", ".join(list(class_names))}\nFeature names for the PIMA Dataset:\n\t{", ".join(list(feature_names))}')
Class names for the PIMA Dataset:
    tested_negative, tested_positive
Feature names for the PIMA Dataset:
    Number_of_times_pregnant, Plasma_glucose_concentration, Diastolic_blood_pressure, Triceps_skin_fold_thickness, 2_Hour_serum_insulin, Body_mass_index, Diabetes_pedigree_function, Age, class
## this code is developed by Jose Maria Alonso-Moral
# Check how the tabular data looks
df.head(n=10)
| | Number_of_times_pregnant | Plasma_glucose_concentration | Diastolic_blood_pressure | Triceps_skin_fold_thickness | 2_Hour_serum_insulin | Body_mass_index | Diabetes_pedigree_function | Age |
|---|---|---|---|---|---|---|---|---|
| 0 | 5.0 | 103.0 | 108.0 | 37.0 | 0.0 | 39.2 | 0.305 | 65.0 |
| 1 | 7.0 | 103.0 | 66.0 | 32.0 | 0.0 | 39.1 | 0.344 | 31.0 |
| 2 | 10.0 | 101.0 | 76.0 | 48.0 | 180.0 | 32.9 | 0.171 | 63.0 |
| 3 | 5.0 | 139.0 | 64.0 | 35.0 | 140.0 | 28.6 | 0.411 | 26.0 |
| 4 | 0.0 | 74.0 | 52.0 | 10.0 | 36.0 | 27.8 | 0.269 | 22.0 |
| 5 | 5.0 | 136.0 | 82.0 | 0.0 | 0.0 | 0.0 | 0.640 | 69.0 |
| 6 | 3.0 | 176.0 | 86.0 | 27.0 | 156.0 | 33.3 | 1.154 | 52.0 |
| 7 | 3.0 | 128.0 | 72.0 | 25.0 | 190.0 | 32.4 | 0.549 | 27.0 |
| 8 | 7.0 | 184.0 | 84.0 | 33.0 | 0.0 | 35.5 | 0.355 | 41.0 |
| 9 | 3.0 | 99.0 | 62.0 | 19.0 | 74.0 | 21.8 | 0.279 | 26.0 |
## this code is developed by Jose Maria Alonso-Moral
# Take a look at the output class distribution for the test set. Remember that 1.0 means tested_negative and 2.0 means tested_positive.
df_test=pd.DataFrame(test_data_value)
plt.figure(figsize=(12,6))
sns.countplot(x='class', data=df_test)
<Axes: xlabel='class', ylabel='count'>
3.2. Model building
Now, we build some ML models over the dataset we just imported: two Decision Trees (which are interpretable-by-design) and a Random Forest (which is opaque). We will also perform a simple validation of their performance.
Some notes about the model creation:
- We are using the default parameters in most situations; changing them will change the performance and, potentially, the interpretations you can extract from the trees. For the purposes of this practical session, we will not go very in-depth about it other than showing the differences that appear between an unlimited Decision Tree and one that has a hard-cap on its depth.
- You will see below that we also train Regressor equivalents of all three models. This is done to show that the approach followed in this section works for regression too, but it is otherwise improper: as mentioned in the previous section, this is a binary classification dataset, and although such a problem can be approximated by a regression on the [0.0, 1.0] range, it is generally inappropriate to use a dataset in a context it was not designed for. Always be mindful of the type of data you are working with!
You can check official documentation of the sklearn models we are going to use here: DecisionTreeClassifier, RandomForestClassifier, DecisionTreeRegressor and RandomForestRegressor.
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
# We fit a Decision Tree, a Decision Tree of limited depth, and a Random Forest.
dtc = tree.DecisionTreeClassifier()
dtc.fit(x_tr, y_tr)
dtc5 = tree.DecisionTreeClassifier(max_depth=5)
dtc5.fit(x_tr, y_tr)
rfc = RandomForestClassifier(n_estimators=1000)
rfc.fit(x_tr, y_tr)
models = [dtc, dtc5, rfc]
models_names = ['TREE', 'TREE5', 'RF']
# Cross-validation for the models just fit above
print("1) Cross-validation (over training data)")
scorings = ['accuracy', 'f1'] # For binary classification
nF= 5
for model, model_name in zip(models, models_names):
    cv_results = cross_validate(model, x_tr, y_tr, cv=nF,
                                scoring=scorings,
                                return_train_score=False)
    print(f'\n\t{model_name}:')
    print(f'\t\tCorrect Classification Rate [Average (St. Dev)] = {np.mean(cv_results["test_accuracy"]):.3f} ({np.std(cv_results["test_accuracy"]):.3f})')
    print(f'\t\tF-Score [Average (St. Dev)] = {np.mean(cv_results["test_f1"]):.3f} ({np.std(cv_results["test_f1"]):.3f})')
# Test with unknown instances
print("\n\n2) Test (with previously unseen data)")
models_acc = []
target_names = ['class 1', 'class 2']
for model, model_name in zip(models, models_names):
    sc = round(model.score(x_test, y_test), 3)  # accuracy rounded to 3 decimals, appended to the list of test-set accuracies
    models_acc.append(sc)
    # Mean accuracy of model.predict(x_test) wrt y_test
    print(f'\n\t{model_name}:')
    print(f'\t\tCorrect Classification Rate: {models_acc[-1]:.3f}')
    y_pred = model.predict(x_test)
    print(classification_report(y_test, y_pred, target_names=target_names))
1) Cross-validation (over training data)
TREE:
Correct Classification Rate [Average (St. Dev)] = 0.690 (0.034)
F-Score [Average (St. Dev)] = 0.550 (0.031)
TREE5:
Correct Classification Rate [Average (St. Dev)] = 0.748 (0.024)
F-Score [Average (St. Dev)] = 0.618 (0.026)
RF:
Correct Classification Rate [Average (St. Dev)] = 0.758 (0.031)
F-Score [Average (St. Dev)] = 0.625 (0.057)
2) Test (with previously unseen data)
TREE:
Correct Classification Rate: 0.688
              precision    recall  f1-score   support

     class 1       0.78      0.72      0.75        50
     class 2       0.55      0.63      0.59        27

    accuracy                           0.69        77
   macro avg       0.67      0.67      0.67        77
weighted avg       0.70      0.69      0.69        77
TREE5:
Correct Classification Rate: 0.727
              precision    recall  f1-score   support

     class 1       0.78      0.80      0.79        50
     class 2       0.62      0.59      0.60        27

    accuracy                           0.73        77
   macro avg       0.70      0.70      0.70        77
weighted avg       0.73      0.73      0.73        77
RF:
Correct Classification Rate: 0.805
              precision    recall  f1-score   support

     class 1       0.89      0.80      0.84        50
     class 2       0.69      0.81      0.75        27

    accuracy                           0.81        77
   macro avg       0.79      0.81      0.79        77
weighted avg       0.82      0.81      0.81        77
## this code is developed by Pablo Miguel Perez-Ferreiro
# Keep in mind we need to use the numerical version of the output for regression equivalents
dtr = tree.DecisionTreeRegressor()
dtr.fit(x_tr, y_tr_num)
dtr5 = tree.DecisionTreeRegressor(max_depth=5)
dtr5.fit(x_tr, y_tr_num)
rfr = RandomForestRegressor(n_estimators=1000)
rfr.fit(x_tr, y_tr_num)
models_r = [dtr, dtr5, rfr]
models_r_names = ['TREE-R', 'TREE5-R', 'RF-R']
# We don't validate for Regressors; it can be done, but it does not fit this problem.
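To illustrate why the regressors can still be compared against the classifiers, their predictions on the [0.0, 1.0] range can be thresholded at 0.5 to recover class labels. The following is a hedged, self-contained sketch on synthetic data (not the PIMA variables above; the dataset, seed, and depth are made up for illustration):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic binary problem with labels encoded as 0.0/1.0,
# mirroring the role of y_tr_num above
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(float)

reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

# Thresholding the regression output turns it back into class labels,
# which allows classification metrics such as accuracy to be computed
y_hat = (reg.predict(X) >= 0.5).astype(float)
print("training accuracy:", (y_hat == y).mean())
```

This conversion is what makes an "accuracy" number for TREE-R, TREE5-R, and RF-R meaningful at all; without the threshold, only regression metrics such as MSE would apply.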
3.3. Visualization
An advantage of the interpretable models we just trained (i.e., the Decision Trees) is that they can be easily visualized. When DTs are reasonably shallow, this may suffice to explain:
- Their global behavior, as the general structure of the tree provides a certain understanding of the tree's priorities when it comes to classifying instances. In this sense, a Decision Tree always gauges the importance of its features (implicitly): consider that most algorithms for decision tree creation build the decision splits based on information gain measures, which means that more informative features will be placed higher on the tree's branches.
- Their local behavior, as any given instance will follow a certain path until reaching a leaf node. This path of conditions can be understood as an ad-hoc rule that justifies the instance's classification.
These two properties are limited, however, by the degree to which the tree's features are meaningful: highly complex features may not be easily understood, which limits their usefulness as an explanatory tool. Plain visualization can also become unwieldy when DTs are very deep and therefore nest many splits. You can check that effect by comparing the graphical representations below (unlimited depth vs. limited depth); in them, each node lists:
- The split condition. Leaf nodes omit this, as there are no more conditions to check.
- The Gini index, which measures node impurity and is the criterion that guides splits. It can be swapped for entropy at model creation; check the documentation previously linked.
- The number of samples that have 'reached' that node.
- The distribution of classes amongst those samples: [A, B] represents [tested_negative, tested_positive].
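The "ad-hoc rule" reading of a decision path can be made concrete with sklearn's `decision_path`. The sketch below is self-contained and uses the Iris dataset rather than the PIMA trees above, so the variable names (`clf`, `sample`, `rule`) and the chosen depth are purely illustrative:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

sample = iris.data[80:81]                     # a single instance
node_indicator = clf.decision_path(sample)    # sparse matrix of visited nodes
leaf = clf.apply(sample)[0]                   # id of the leaf the instance reaches
feature, threshold = clf.tree_.feature, clf.tree_.threshold

# Turn the root-to-leaf path into an ad-hoc rule of split conditions
rule = []
for node in node_indicator.indices:
    if node == leaf:                          # leaf nodes hold no condition
        continue
    op = "<=" if sample[0, feature[node]] <= threshold[node] else ">"
    rule.append(f"{iris.feature_names[feature[node]]} {op} {threshold[node]:.2f}")

print(" AND ".join(rule), "->", iris.target_names[clf.predict(sample)[0]])
```

Each printed condition is one internal node the instance traversed; the conjunction of those conditions is exactly the local justification discussed above.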
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
# Visualizing entire decision tree
print(export_text(dtc, feature_names=pima_fnames))
dot_data = tree.export_graphviz(dtc, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=pima_fnames)
graph = graphviz.Source(dot_data)
graph#.render(format='png') # Big trees are not very easy to examine on the notebook, so you may uncomment the render call to save them to a PNG file so that you can check out the graph elsewhere.
[Output of export_text(dtc, ...): the full unlimited-depth Decision Tree as nested text rules, spanning several hundred split conditions (with some branches reported as truncated). The notebook export flattened it into unreadable single lines, so it is omitted here; re-run the cell to reproduce it.]
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
# Visualizing the reduced tree
print(export_text(dtc5, feature_names=pima_fnames))
dot_data5 = tree.export_graphviz(dtc5, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=pima_fnames)
graph5 = graphviz.Source(dot_data5)
graph5#.render(format='png') # Big trees are not very easy to examine on the notebook, so you may uncomment the render call to save them to a PNG file so that you can check out the graph elsewhere.
|--- Plasma glucose concentration <= 143.50
|   |--- Body mass index <= 27.35
|   |   |--- Plasma glucose concentration <= 105.50
|   |   |   |--- class: False
|   |   |--- Plasma glucose concentration > 105.50
|   |   |   |--- Body mass index <= 9.80
|   |   |   |   |--- Number of times pregnant <= 7.00
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Number of times pregnant > 7.00
|   |   |   |   |   |--- class: True
|   |   |   |--- Body mass index > 9.80
|   |   |   |   |--- Plasma glucose concentration <= 108.50
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Plasma glucose concentration > 108.50
|   |   |   |   |   |--- class: False
|   |--- Body mass index > 27.35
|   |   |--- Plasma glucose concentration <= 99.50
|   |   |   |--- Age <= 25.50
|   |   |   |   |--- Diastolic blood_pressure <= 83.50
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Diastolic blood_pressure > 83.50
|   |   |   |   |   |--- class: False
|   |   |   |--- Age > 25.50
|   |   |   |   |--- Age <= 27.50
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Age > 27.50
|   |   |   |   |   |--- class: False
|   |   |--- Plasma glucose concentration > 99.50
|   |   |   |--- Age <= 30.50
|   |   |   |   |--- Diastolic blood_pressure <= 22.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Diastolic blood_pressure > 22.00
|   |   |   |   |   |--- class: False
|   |   |   |--- Age > 30.50
|   |   |   |   |--- Diabetes pedigree function <= 0.53
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Diabetes pedigree function > 0.53
|   |   |   |   |   |--- class: True
|--- Plasma glucose concentration > 143.50
|   |--- Plasma glucose concentration <= 154.50
|   |   |--- Diabetes pedigree function <= 0.33
|   |   |   |--- Diabetes pedigree function <= 0.18
|   |   |   |   |--- Triceps skin fold thickness <= 28.50
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Triceps skin fold thickness > 28.50
|   |   |   |   |   |--- class: False
|   |   |   |--- Diabetes pedigree function > 0.18
|   |   |   |   |--- Age <= 67.50
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Age > 67.50
|   |   |   |   |   |--- class: True
|   |   |--- Diabetes pedigree function > 0.33
|   |   |   |--- Age <= 31.50
|   |   |   |   |--- Diabetes pedigree function <= 0.37
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Diabetes pedigree function > 0.37
|   |   |   |   |   |--- class: False
|   |   |   |--- Age > 31.50
|   |   |   |   |--- Plasma glucose concentration <= 152.50
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Plasma glucose concentration > 152.50
|   |   |   |   |   |--- class: False
|   |--- Plasma glucose concentration > 154.50
|   |   |--- Body mass index <= 29.85
|   |   |   |--- Age <= 25.50
|   |   |   |   |--- class: False
|   |   |   |--- Age > 25.50
|   |   |   |   |--- Age <= 61.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Age > 61.00
|   |   |   |   |   |--- class: False
|   |   |--- Body mass index > 29.85
|   |   |   |--- Diabetes pedigree function <= 0.13
|   |   |   |   |--- class: False
|   |   |   |--- Diabetes pedigree function > 0.13
|   |   |   |   |--- Age <= 44.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Age > 44.00
|   |   |   |   |   |--- class: True
We can also plot the regressor trees. You'll see their format is mostly equivalent to that above, although they use squared error instead of the Gini index (regression vs. classification, as noted) and report the value that would be assigned to a data instance that ends up at that node.
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
# Visualizing the reduced regressor tree
print(export_text(dtr5, feature_names=pima_fnames))
dot_data5r = tree.export_graphviz(dtr5, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=pima_fnames)
graph5r = graphviz.Source(dot_data5r)
graph5r#.render(format='png') # Big trees are not very easy to examine on the notebook, so you may uncomment the render call to save them to a PNG file so that you can check out the graph elsewhere.
|--- Plasma glucose concentration <= 143.50
|   |--- Body mass index <= 27.35
|   |   |--- Plasma glucose concentration <= 105.50
|   |   |   |--- value: [0.00]
|   |   |--- Plasma glucose concentration > 105.50
|   |   |   |--- Body mass index <= 9.80
|   |   |   |   |--- Number of times pregnant <= 7.00
|   |   |   |   |   |--- value: [0.00]
|   |   |   |   |--- Number of times pregnant > 7.00
|   |   |   |   |   |--- value: [1.00]
|   |   |   |--- Body mass index > 9.80
|   |   |   |   |--- Plasma glucose concentration <= 108.50
|   |   |   |   |   |--- value: [0.23]
|   |   |   |   |--- Plasma glucose concentration > 108.50
|   |   |   |   |   |--- value: [0.05]
|   |--- Body mass index > 27.35
|   |   |--- Plasma glucose concentration <= 99.50
|   |   |   |--- Age <= 25.50
|   |   |   |   |--- Diastolic blood_pressure <= 83.50
|   |   |   |   |   |--- value: [0.00]
|   |   |   |   |--- Diastolic blood_pressure > 83.50
|   |   |   |   |   |--- value: [0.33]
|   |   |   |--- Age > 25.50
|   |   |   |   |--- Age <= 27.50
|   |   |   |   |   |--- value: [0.67]
|   |   |   |   |--- Age > 27.50
|   |   |   |   |   |--- value: [0.17]
|   |   |--- Plasma glucose concentration > 99.50
|   |   |   |--- Age <= 30.50
|   |   |   |   |--- Diastolic blood_pressure <= 22.00
|   |   |   |   |   |--- value: [0.86]
|   |   |   |   |--- Diastolic blood_pressure > 22.00
|   |   |   |   |   |--- value: [0.25]
|   |   |   |--- Age > 30.50
|   |   |   |   |--- Diabetes pedigree function <= 0.53
|   |   |   |   |   |--- value: [0.45]
|   |   |   |   |--- Diabetes pedigree function > 0.53
|   |   |   |   |   |--- value: [0.80]
|--- Plasma glucose concentration > 143.50
|   |--- Plasma glucose concentration <= 154.50
|   |   |--- Diabetes pedigree function <= 0.33
|   |   |   |--- Diabetes pedigree function <= 0.18
|   |   |   |   |--- Plasma glucose concentration <= 151.00
|   |   |   |   |   |--- value: [1.00]
|   |   |   |   |--- Plasma glucose concentration > 151.00
|   |   |   |   |   |--- value: [0.00]
|   |   |   |--- Diabetes pedigree function > 0.18
|   |   |   |   |--- Age <= 67.50
|   |   |   |   |   |--- value: [0.08]
|   |   |   |   |--- Age > 67.50
|   |   |   |   |   |--- value: [1.00]
|   |   |--- Diabetes pedigree function > 0.33
|   |   |   |--- Age <= 31.50
|   |   |   |   |--- Diabetes pedigree function <= 0.37
|   |   |   |   |   |--- value: [1.00]
|   |   |   |   |--- Diabetes pedigree function > 0.37
|   |   |   |   |   |--- value: [0.12]
|   |   |   |--- Age > 31.50
|   |   |   |   |--- Plasma glucose concentration <= 152.50
|   |   |   |   |   |--- value: [0.94]
|   |   |   |   |--- Plasma glucose concentration > 152.50
|   |   |   |   |   |--- value: [0.33]
|   |--- Plasma glucose concentration > 154.50
|   |   |--- Body mass index <= 29.85
|   |   |   |--- Age <= 25.50
|   |   |   |   |--- value: [0.00]
|   |   |   |--- Age > 25.50
|   |   |   |   |--- Age <= 61.00
|   |   |   |   |   |--- value: [0.80]
|   |   |   |   |--- Age > 61.00
|   |   |   |   |   |--- value: [0.00]
|   |   |--- Body mass index > 29.85
|   |   |   |--- Diabetes pedigree function <= 0.13
|   |   |   |   |--- value: [0.00]
|   |   |   |--- Diabetes pedigree function > 0.13
|   |   |   |   |--- Age <= 44.00
|   |   |   |   |   |--- value: [0.93]
|   |   |   |   |--- Age > 44.00
|   |   |   |   |   |--- value: [0.76]
Be aware that, convenient as it may be, this type of interpretation can't be performed for all (or most) ML approaches. For instance, the Random Forest we just trained can't be visualized like this, and in fact is not interpretable by itself. However, Random Forests are usually very strong in terms of predictive performance, and thus ideally, we would like to make them at least a bit understandable. We will need a different tool for this.
Exercise I1.1¶
Try to build some alternate versions of the Decision Trees we obtained by changing the parameters on the constructor call. You don't need to do this for the Regressors. Refer to this documentation for the available parameters. You don't need to be exhaustive: trying two or three is enough. You may keep limiting their depth to keep their interpretation manageable (but you don't need to).
Plot these new decision trees and compare them to the two trees (dtc and dtc5) given as illustrative examples in the previous cells. Are they significantly different? Extract some conclusions as to why/why not, and summarize a few key insights onto the diabetes prediction problem you can distill from the trends shown by your set of trees.
# this code is developed by Miguel Leal and Gian Paolo Bulleddu
#==========================================================
# We now fit some alternative Decision Tree models with different hyperparameters,
# so that we can check how the interpretability and performance metrics change.
# We fit one Decision Tree with the same maximum depth as before but using the
# 'entropy' criterion instead of 'gini', and another with the same maximum depth
# but a minimum of 3 samples per leaf, which reduces the complexity of the tree
# and thus increases its interpretability.
#==========================================================
dtcv1 = tree.DecisionTreeClassifier(max_depth=5, criterion='entropy')
dtcv1.fit(x_tr, y_tr)
dtcv2 = tree.DecisionTreeClassifier(max_depth=5, min_samples_leaf=3)
dtcv2.fit(x_tr, y_tr)
models_alt = [dtcv1, dtcv2]
models_names_alt = ['TREEv1', 'TREEv2']
# Cross-validation for the models just fit above
print("1) Cross-validation (over training data)")
scorings = ['accuracy', 'f1']  # For binary classification
nF = 5
for model, model_name in zip(models_alt, models_names_alt):
    cv_results = cross_validate(model, x_tr, y_tr, cv=nF,
                                scoring=scorings,
                                return_train_score=False)
    print(f'\n\t{model_name}:')
    print(f'\t\tCorrect Classification Rate [Average (St. Dev)] = {np.mean(cv_results["test_accuracy"]):.3f} ({np.std(cv_results["test_accuracy"]):.3f})')
    print(f'\t\tF-Score [Average (St. Dev)] = {np.mean(cv_results["test_f1"]):.3f} ({np.std(cv_results["test_f1"]):.3f})')
# Test with unknown instances
print("\n\n2) Test (with previously unseen data)")
models_acc_alt = []
target_names = ['class 1', 'class 2']
for model, model_name in zip(models_alt, models_names_alt):
    sc = round(model.score(x_test, y_test), 3)  # accuracy rounded to 3 decimals, appended to the list of test-set accuracies
    models_acc_alt.append(sc)
    # Mean accuracy of model.predict(x_test) wrt y_test
    print(f'\n\t{model_name}:')
    print(f'\t\tCorrect Classification Rate: {models_acc_alt[-1]:.3f}')
    y_pred = model.predict(x_test)
    print(classification_report(y_test, y_pred, target_names=target_names))
1) Cross-validation (over training data)
TREEv1:
Correct Classification Rate [Average (St. Dev)] = 0.729 (0.051)
F-Score [Average (St. Dev)] = 0.576 (0.087)
TREEv2:
Correct Classification Rate [Average (St. Dev)] = 0.737 (0.029)
F-Score [Average (St. Dev)] = 0.591 (0.046)
2) Test (with previously unseen data)
TREEv1:
Correct Classification Rate: 0.805
precision recall f1-score support
class 1 0.86 0.84 0.85 50
class 2 0.71 0.74 0.73 27
accuracy 0.81 77
macro avg 0.79 0.79 0.79 77
weighted avg 0.81 0.81 0.81 77
TREEv2:
Correct Classification Rate: 0.727
precision recall f1-score support
class 1 0.78 0.80 0.79 50
class 2 0.62 0.59 0.60 27
accuracy 0.73 77
macro avg 0.70 0.70 0.70 77
weighted avg 0.73 0.73 0.73 77
# this code is developed by Miguel Leal and Gian Paolo Bulleddu
# Visualizing the alternative reduced tree with 'entropy' criterion
print(export_text(dtcv1, feature_names=pima_fnames))
dot_datav1 = tree.export_graphviz(dtcv1, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=pima_fnames)
graphv1 = graphviz.Source(dot_datav1)
graphv1#.render(format='png') # Big trees are not very easy to examine on the notebook, so you may uncomment the render call to save them to a PNG file so that you can check out the graph elsewhere.
|--- Plasma glucose concentration <= 127.50
|   |--- Body mass index <= 26.45
|   |   |--- Body mass index <= 9.10
|   |   |   |--- Number of times pregnant <= 7.50
|   |   |   |   |--- class: False
|   |   |   |--- Number of times pregnant > 7.50
|   |   |   |   |--- class: True
|   |   |--- Body mass index > 9.10
|   |   |   |--- Diabetes pedigree function <= 0.68
|   |   |   |   |--- class: False
|   |   |   |--- Diabetes pedigree function > 0.68
|   |   |   |   |--- Diabetes pedigree function <= 0.71
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Diabetes pedigree function > 0.71
|   |   |   |   |   |--- class: False
|   |--- Body mass index > 26.45
|   |   |--- Age <= 28.50
|   |   |   |--- Body mass index <= 30.95
|   |   |   |   |--- Number of times pregnant <= 7.00
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Number of times pregnant > 7.00
|   |   |   |   |   |--- class: True
|   |   |   |--- Body mass index > 30.95
|   |   |   |   |--- Diastolic blood_pressure <= 37.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Diastolic blood_pressure > 37.00
|   |   |   |   |   |--- class: False
|   |   |--- Age > 28.50
|   |   |   |--- Plasma glucose concentration <= 99.50
|   |   |   |   |--- Plasma glucose concentration <= 28.50
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Plasma glucose concentration > 28.50
|   |   |   |   |   |--- class: False
|   |   |   |--- Plasma glucose concentration > 99.50
|   |   |   |   |--- Diabetes pedigree function <= 0.56
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Diabetes pedigree function > 0.56
|   |   |   |   |   |--- class: True
|--- Plasma glucose concentration > 127.50
|   |--- Plasma glucose concentration <= 154.50
|   |   |--- Body mass index <= 28.85
|   |   |   |--- Number of times pregnant <= 1.50
|   |   |   |   |--- class: False
|   |   |   |--- Number of times pregnant > 1.50
|   |   |   |   |--- Body mass index <= 23.45
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Body mass index > 23.45
|   |   |   |   |   |--- class: False
|   |   |--- Body mass index > 28.85
|   |   |   |--- Diabetes pedigree function <= 0.44
|   |   |   |   |--- Body mass index <= 41.80
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Body mass index > 41.80
|   |   |   |   |   |--- class: True
|   |   |   |--- Diabetes pedigree function > 0.44
|   |   |   |   |--- Age <= 30.00
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Age > 30.00
|   |   |   |   |   |--- class: True
|   |--- Plasma glucose concentration > 154.50
|   |   |--- Body mass index <= 29.85
|   |   |   |--- Age <= 25.50
|   |   |   |   |--- class: False
|   |   |   |--- Age > 25.50
|   |   |   |   |--- Age <= 61.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Age > 61.00
|   |   |   |   |   |--- class: False
|   |   |--- Body mass index > 29.85
|   |   |   |--- Diastolic blood_pressure <= 67.00
|   |   |   |   |--- class: True
|   |   |   |--- Diastolic blood_pressure > 67.00
|   |   |   |   |--- Two hour serum insulin <= 661.50
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Two hour serum insulin > 661.50
|   |   |   |   |   |--- class: False
# this code is developed by Miguel Leal and Gian Paolo Bulleddu
# Visualizing the alternative reduced tree with min_samples_leaf=3
print(export_text(dtcv2, feature_names=pima_fnames))
dot_datav2 = tree.export_graphviz(dtcv2, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=pima_fnames)
graphv2 = graphviz.Source(dot_datav2)
graphv2#.render(format='png') # Big trees are not very easy to examine on the notebook, so you may uncomment the render call to save them to a PNG file so that you can check out the graph elsewhere.
Explanation:
In this exercise we built two decision tree classifiers by varying some hyperparameters. In both models we kept the maximum depth at 5; in the first model we changed the Gini impurity criterion to entropy, while in the second we set the minimum number of samples per leaf to three. The purpose of this task is to evaluate how the decision logic differs from the predefined decision trees. After training the newly created models, we analysed the results; our conclusions follow.
First of all, we did not detect big structural differences between the new trees (dtcv1, dtcv2) and the original ones (dtc, dtc5).
Plasma glucose concentration most frequently appears as the root node, or very close to it, in all models, which suggests it is the most important feature.
Body mass index is also often selected in the upper levels of the trees, frequently acting as a secondary decision variable.
Age and diabetes pedigree function are most frequently selected at intermediate tree levels.
The remaining features (two-hour serum insulin, diastolic blood pressure, triceps skin fold thickness) tend to appear only in the deeper levels of the trees, or not at all, so we can conclude they are the least important features.
The entropy splitting criterion did not lead to big changes in tree structure or split thresholds, nor did it change the importance or ordering of the features: the entropy-based tree selects similar variables at the top and produces comparable decision paths. Regardless of the impurity criterion used, the models identify the same most important features.
Limiting the minimum number of samples per leaf produces a simpler tree; however, even after removing highly specific leaves and eliminating splits based on few observations, the structure looks almost unchanged, especially at the top levels. The main difference is that this tree does not have any node split on the number of times pregnant.
In conclusion, none of the trained trees show big differences in structure or logic: the prediction of a diabetes diagnosis appears to be driven mainly by plasma glucose concentration and body mass index, with age and diabetes pedigree function also important but with less influence on routing the model to a good prediction.
3.4. Using SHAP Values for interpretability¶
We can use SHAP values to tackle some of the shortcomings mentioned in the previous section:
- SHAP is much more practical when dealing with, for example, a very deep tree with many nested conditions, as it offers a single importance value per feature. In this sense, it summarizes the explanatory information that is spread among the branches of the tree, although SHAP values do not necessarily offer the same interpretation you would extract from a tree's branches.
- SHAP, as a post-hoc explainability method, can endow non-interpretable models (such as Random Forests) with interpretability.
SHAP still has a problem with semantic significance (i.e. features having meaning). Because it only gauges the importance of features, we are still fully dependent on those features (and their impact) offering understandable insight into the problem.
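For reference, the attribution SHAP assigns to feature $i$ is the classical Shapley value: the feature's marginal contribution $f(S \cup \{i\}) - f(S)$, averaged over all subsets $S$ of the remaining features $F \setminus \{i\}$ with the usual combinatorial weights:

```latex
\phi_i \;=\; \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,\bigl(|F| - |S| - 1\bigr)!}{|F|!} \,\Bigl[ f\bigl(S \cup \{i\}\bigr) - f(S) \Bigr]
```

TreeExplainer computes these values exactly for tree ensembles, which is why it is the right choice for the models in this notebook.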
We'll see how to generate both global and local explanations with SHAP:
Global explanations¶
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
# Building the explainers and computing Shapley values
# Unbound DT
explainerTreeC = shap.TreeExplainer(dtc)
shap_values_DTC = explainerTreeC(x_test)
# Limited-depth DT
explainerTree5C = shap.TreeExplainer(dtc5)
shap_values_DTC5 = explainerTree5C(x_test)
# Random Forest
explainerTreeRFC = shap.TreeExplainer(rfc)
shap_values_RFC = explainerTreeRFC(x_test)
explainers = [explainerTreeC, explainerTree5C, explainerTreeRFC]
# A general commentary on the shap_values_{model} indexing: the SHAP library interprets any classification as multi-class classification (you will see later that indexing is not needed for regressors), which means we need to specify
# the class we're explaining for. Then, [a, b, c] ---> a selects the rows (instances on the dataset), b selects the columns (variables on the dataset), c selects the class (so, 0 for tested_negative, 1 for tested_positive).
# You may alter the plots to work with limited instance ranges, different column sets, or even for the opposite class.
shap.summary_plot(shap_values_DTC[:,:,1])
shap.summary_plot(shap_values_DTC5[:,:,1])
shap.summary_plot(shap_values_RFC[:,:,1])
## this code is developed by Pablo Miguel Perez-Ferreiro
# Equivalent for Regressors
# Unbound DT
explainerTreeR = shap.TreeExplainer(dtr)
shap_values_DTR = explainerTreeR(x_test)
# Limited-depth DT
explainerTree5R = shap.TreeExplainer(dtr5)
shap_values_DTR5 = explainerTree5R(x_test)
# Random Forest
explainerTreeRFR = shap.TreeExplainer(rfr)
shap_values_RFR = explainerTreeRFR(x_test)
explainers_r = [explainerTreeR, explainerTree5R, explainerTreeRFR]
shap.summary_plot(shap_values_DTR)
shap.summary_plot(shap_values_DTR5)
shap.summary_plot(shap_values_RFR)
Exercise I1.2¶
Now you have an alternative explanation tool against which you can contrast the insight you extracted on Exercise I1.1, and some knowledge onto how the Random Forest is predicting. Try to interpret the graphs corresponding to the Classifier alternatives:
- Do the SHAP values for the Decision Trees match your inspection of their graphical representation?
- Do you find significant differences between the Decision Trees and the Random Forest?
It is a good idea to read the SHAP documentation in order to better understand how to interpret the generated plots. Keep in mind that the answer to any of these questions may be negative, but you should still try to theorise as to why (justify your responses).
Do the SHAP values for the Decision Trees match your inspection of their graphical representation?
Yes, the SHAP values confirm what we saw in the decision tree. Plasma glucose concentration has the highest impact on the predictions, just like it appears at the top of the tree splits. Other features, such as body mass index, age, and diabetes pedigree function, also show significant influence, which matches their presence in the higher levels of the tree. The SHAP summary plot shows that high values (red points) of plasma glucose generally push predictions toward the positive class (dots on the right), consistent with the tree's decision rules.
Do you find significant differences between the Decision Trees and the Random Forest?
Yes, there are differences between the two types of models. The Random Forest exhibits less dispersion in SHAP values, with points more densely grouped and fewer outliers. This shows the ensemble nature of the Random Forest, where predictions are averaged over many trees, reducing variability and sensitivity to individual data points. In contrast, single Decision Trees tend to produce more scattered SHAP values and sharper transitions due to their reliance on hard decision thresholds.
Local explanations¶
Keep in mind that, for a local explanation, we may justify a wrong prediction. This can be misleading, so please pay close attention to it.
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
for instance in [0, 24, 59, 76]: # specific cases, feel free to alter them
print('----------------------------------------------------------------------------------------------------------------------------')
print(f'\nWORKING WITH INSTANCE {instance}:\nThe real output class is: {y_test.iloc[instance]}')
for model, model_name, explainer in zip(models, models_names, explainers):
print(f'The predicted output class by model {model_name} is: {model.predict(x_test.iloc[instance : instance+1])[0]}')
shap_values = explainer(x_test.iloc[instance : instance+1])
shap.summary_plot(shap_values[:,:,1])
print('----------------------------------------------------------------------------------------------------------------------------')
---------------------------------------------------------------------------------------------------------------------------- WORKING WITH INSTANCE 0: The real output class is: False The predicted output class by model TREE is: True
---------------------------------------------------------------------------------------------------------------------------- The predicted output class by model TREE5 is: False
---------------------------------------------------------------------------------------------------------------------------- The predicted output class by model RF is: False
---------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------- WORKING WITH INSTANCE 24: The real output class is: False The predicted output class by model TREE is: False
---------------------------------------------------------------------------------------------------------------------------- The predicted output class by model TREE5 is: False
---------------------------------------------------------------------------------------------------------------------------- The predicted output class by model RF is: False
---------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------- WORKING WITH INSTANCE 59: The real output class is: True The predicted output class by model TREE is: False
---------------------------------------------------------------------------------------------------------------------------- The predicted output class by model TREE5 is: False
---------------------------------------------------------------------------------------------------------------------------- The predicted output class by model RF is: True
---------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------- WORKING WITH INSTANCE 76: The real output class is: True The predicted output class by model TREE is: True
---------------------------------------------------------------------------------------------------------------------------- The predicted output class by model TREE5 is: True
---------------------------------------------------------------------------------------------------------------------------- The predicted output class by model RF is: True
----------------------------------------------------------------------------------------------------------------------------
## this code is developed by Pablo Miguel Perez-Ferreiro
# We will only execute one instance for regression to avoid cluttering the notebook.
for instance in [0]:
    print('----------------------------------------------------------------------------------------------------------------------------')
    print(f'\nWORKING WITH INSTANCE {instance}:\nThe real output class is: {y_test_num[instance]}')
    for model, model_name, explainer in zip(models_r, models_r_names, explainers_r):
        print(f'The predicted output class by model {model_name} is: {model.predict(x_test.iloc[instance : instance+1])[0]}')
        shap_values = explainer(x_test.iloc[instance : instance+1])
        shap.summary_plot(shap_values)
        print('----------------------------------------------------------------------------------------------------------------------------')
---------------------------------------------------------------------------------------------------------------------------- WORKING WITH INSTANCE 0: The real output class is: 0.0 The predicted output class by model TREE-R is: 1.0
---------------------------------------------------------------------------------------------------------------------------- The predicted output class by model TREE5-R is: 0.25
---------------------------------------------------------------------------------------------------------------------------- The predicted output class by model RF-R is: 0.381
----------------------------------------------------------------------------------------------------------------------------
Exercise I1.3¶
Have a look at SHAP's API reference. So far, we have only used the basic summary plot, but one of the library's greatest strengths is its rich plotting environment. Try your hand at generating new plots (both global and local) following the code given during the section. Do you think they offer better information than the summaries we were using? Which plotting options do you think would be the best at explaining the problem at hand to a layperson? Would your answer change if the explanations were meant for an expert?
# You may need to execute the following for some of the plots to work
shap.initjs()
for instance in [24]:  # specific cases, feel free to alter them
    print('----------------------------------------------------------------------------------------------------------------------------')
    print(f'\nWORKING WITH INSTANCE {instance}:\nThe real output class is: {y_test.iloc[instance]}')
    for model, model_name, explainer in zip(models, models_names, explainers):
        print(f'The predicted output class by model {model_name} is: {model.predict(x_test.iloc[instance : instance+1])[0]}')
        shap_values = explainer(x_test.iloc[instance : instance+1])
        # select class 1
        shap_class = shap_values[:, :, 1]
        print("Bar:")
        shap.plots.bar(shap_class)
        print("Waterfall:")
        shap.plots.waterfall(shap_class[0])
        print("Partial dependence:")
        shap.plots.partial_dependence("Age", model.predict, x_test)
        print('----------------------------------------------------------------------------------------------------------------------------')
---------------------------------------------------------------------------------------------------------------------------- WORKING WITH INSTANCE 24: The real output class is: False The predicted output class by model TREE is: False Bar:
Waterfall:
Partial dependence:
---------------------------------------------------------------------------------------------------------------------------- The predicted output class by model TREE5 is: False Bar:
Waterfall:
Partial dependence:
---------------------------------------------------------------------------------------------------------------------------- The predicted output class by model RF is: False Bar:
Waterfall:
Partial dependence:
----------------------------------------------------------------------------------------------------------------------------
Explanation:
Yes, these plots do offer better and more complete information than the summary plot alone, although they should be seen as complementary rather than as a substitute. The SHAP summary plot helps in understanding the global behaviour of the model: it highlights which features are the most important and how their values influence the predictions. However, it is quite abstract and does not explain individual model decisions.
The bar and waterfall plots provide clearer local explanations. The bar plot makes it easy to see which features contribute the most to a single prediction, while the waterfall plot, starting from a baseline, explicitly shows how each feature pushes the model's output toward or away from the predicted class. This makes the decision process more interpretable at the instance level and helps us understand why a specific patient is classified as diabetic or not. Partial dependence plots add another perspective by showing how a feature affects predictions on average across the dataset, which is useful for understanding general trends.
For a layperson, the bar and especially the waterfall plots are the most effective. They are intuitive, visually clear, and make it easy to explain which factors increase the risk and which decrease it for a given individual. In contrast, the summary plot and partial dependence plots can be harder to interpret without a technical background.
If the explanations were meant for an expert, the answer would change. An expert audience would benefit more from the summary plot and partial dependence plots, as these provide global insight into model behaviour, feature interactions, and overall consistency. Local plots would still be useful, but mainly as a complement for analysing specific cases or debugging the model, rather than as the primary explanation tool.
3.5. Trade-offs¶
We have now studied two ways of shedding light on ML models: direct inspection when the model allows it (because it is interpretable by design), or applying SHAP as a post-hoc method. It is important to note, however, that interpretability, while very important, is not to be gauged in a vacuum. A very interpretable model that performs poorly is useless: it may explain its reasoning, but the reasoning itself is flawed and there is no point in understanding it.
In a less extreme situation, we may find ourselves faced with a decision between a highly interpretable model that performs reasonably well, and a less interpretable model that performs even better. Whatever we choose, we have a trade-off, and a good way to decide is through the construction of a Pareto front that allows us to objectively compare the relative virtues of all options, presenting a performance metric against an interpretability metric:
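The notion of non-domination behind a Pareto front can be sketched in a few lines. The helper below is hypothetical (it is not the notebook's `plot_pareto_front`); the accuracy/leaf figures are the ones reported for TREE, TREE5 and RF in this section:

```python
# Hypothetical sketch: a model is Pareto-optimal (non-dominated) if no other
# model is both at least as accurate and at least as interpretable (fewer
# leaves), with a strict improvement in at least one of the two.
def pareto_front(accuracies, complexities):
    """Return indices of non-dominated models (accuracy up, complexity down)."""
    front = []
    for i, (a_i, c_i) in enumerate(zip(accuracies, complexities)):
        dominated = any(
            (a_j >= a_i and c_j <= c_i) and (a_j > a_i or c_j < c_i)
            for j, (a_j, c_j) in enumerate(zip(accuracies, complexities))
            if j != i
        )
        if not dominated:
            front.append(i)
    return front

# Accuracy and leaf counts reported for TREE, TREE5 and RF in this section:
print(pareto_front([0.688, 0.727, 0.805], [132, 27, 198]))  # → [1, 2]
```

TREE is dominated by TREE5 (more accurate and far fewer leaves), so only TREE5 and RF sit on the front, matching the plot.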
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
# Building and visualizing the Pareto front with predictive accuracy versus number of leaf nodes, which can be understood as a surrogate for interpretability (remember
# how the largest tree was much harder to visually inspect)
limits = [min([dtc.get_n_leaves(), dtc5.get_n_leaves()])*0.75, max([dtc.get_n_leaves(), dtc5.get_n_leaves()])*1.5]
# Random Forest is a black box model and there is no easy, fair way to give it a nominal leaf node amount. We settle here for multiplying the largest tree x1.5 (penalty for opaqueness).
x_axis = [dtc.get_n_leaves(), dtc5.get_n_leaves(), limits[1]]
print("The data for the Pareto front is as follows:")
for name, accuracy, leaves in zip(models_names, models_acc, x_axis):
print(f'\t-For model {name}, accuracy is {accuracy} with a total of {leaves} leaf nodes.')
plt.figure(figsize=[15,10])
plot_pareto_front(x_axis, models_acc, models_names, 'Accuracy (Classification Ratio)','Interpretability (Num of rules / leaves)', limits[0], limits[1])
The data for the Pareto front is as follows: -For model TREE, accuracy is 0.688 with a total of 132 leaf nodes. -For model TREE5, accuracy is 0.727 with a total of 27 leaf nodes. -For model RF, accuracy is 0.805 with a total of 198.0 leaf nodes.
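As an aside, the ×1.5 opaqueness penalty used above is one choice among several; another option would be to count the forest's actual leaves across all of its trees. A hypothetical sketch on a freshly trained forest (synthetic data, not the notebook's rfc):

```python
# Hypothetical alternative to the fixed x1.5 penalty: count the real number of
# leaves across every tree in a Random Forest via its estimators_ attribute.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=200, n_features=5, random_state=0)
rf = RandomForestClassifier(n_estimators=10, max_depth=3, random_state=0).fit(X, y)

total_leaves = sum(est.get_n_leaves() for est in rf.estimators_)
print(total_leaves)  # at most 10 trees x 2^3 = 80 leaves
```

Whether this is a *fair* interpretability score is debatable (no human reads 10 trees at once), which is presumably why the notebook settles for a nominal penalty instead.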
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
# Building and visualizing the Pareto front with Accuracy versus Shap Length
x_axis = [get_shap_explanation_length(explainer.shap_values(x_test)[:,:,1]) for explainer in explainers]
limits = [min(x_axis)*0.99, max(x_axis)*1.01]
print("The data for the Pareto front is as follows:")
for name, accuracy, shap_l in zip(models_names, models_acc, x_axis):
print(f'\t-For model {name}, accuracy is {accuracy} with a SHAP length of {shap_l}.')
plt.figure(figsize=[15,10])
plot_pareto_front(x_axis,models_acc,models_names,'Accuracy (Classification Ratio)','Interpretability (Shap Length)',limits[0],limits[1])
The data for the Pareto front is as follows: -For model TREE, accuracy is 0.688 with a SHAP length of 558. -For model TREE5, accuracy is 0.727 with a SHAP length of 561. -For model RF, accuracy is 0.805 with a SHAP length of 557.
4. A more integral approach: InterpretML¶
In the previous section, we followed the full process of building an intelligent system and trying to understand how it works. However, thanks to the growing interest in AI trustworthiness and explainability (which motivates this subject!), we now have many tools at our disposal that make the life of a trustworthy AI engineer much easier. The InterpretML Python library is intended to do just that: play around with it to round off this first practical session.
## this code is developed by Jose Maria Alonso-Moral
# This is InterpretML's own approach: a 'glass-box' model that aims to provide opaque-model performance with full explainability.
# As such, it is already designed with explanations in mind, and supports them natively.
ebm = ExplainableBoostingClassifier()
ebm.fit(x_tr, y_tr)
# This will provide global insight on the model. You can change tabs to study the effect of individual variables, and also of interactions between pairs of them.
ebm_global = ebm.explain_global()
show(ebm_global)
## this code is developed by Jose Maria Alonso-Moral
# This code explains local instances
instance = 0
ebm_local = ebm.explain_local(x_test.iloc[instance : instance+1], y_test.iloc[instance : instance+1])
show(ebm_local)
Exercise I1.4¶
Now that you've reached the end of this first practical session, you have all the tools needed to perform the following task autonomously. Use the code below to import the SONGS dataset, present in your workspace, which contains 2017 data instances representing songs from a Kaggle competition. The dataset is intended for the classification problem of predicting whether a song will be labelled 'Like' or 'Dislike'.
With the dataset imported, follow the process laid out throughout this notebook to explain the global behaviour of the model. Then, find a song you like (and another song that you don't like) in the dataset (using the code provided), and locally explain the predictions made by a decision tree with a good interpretability-accuracy trade-off and by a Random Forest. Discuss all your findings.
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
# importing the dataset
file_songs = 'testlib/SONGS/SONGS.arff'
file_songs_data = 'testlib/SONGS/spotifyData.csv'
with open(file_songs) as f:
    songs_data_value, songs_attributes = loadarff(f)
# The 'with' block closes the file automatically, so no explicit f.close() is needed
song_names = pd.read_csv(file_songs_data).loc[:, ['artist', 'song_title', 'id']]
songs_fnames = ["acousticness","danceability","duration_ms","energy","instrumentalness","key","liveness","loudness","mode","speechiness","tempo","time_signature","valence"]
songs_class_names = ['Dislike','Like']
# You don't need to perform a train/test split or validation of the models later.
songs_class_names = np.array(songs_class_names)
songs_feature_names = np.array(songs_attributes.names())
df_songs=pd.DataFrame(songs_data_value)
df_songs.columns = songs_feature_names
songs_target = df_songs.pop('class')
songs_target_onehot = pd.get_dummies(songs_target)[b'2']
X_songs = df_songs
y_songs = songs_target_onehot
# You don't need the numerical variant, as you don't need to do the Regression equivalents.
print(f'Class names for the SONGS Dataset:\n\t{", ".join(list(songs_class_names))}\nFeature names for the SONGS Dataset:\n\t{", ".join(list(songs_feature_names))}')
Class names for the SONGS Dataset:
	Dislike, Like
Feature names for the SONGS Dataset:
	acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, mode, speechiness, tempo, time_signature, valence, class
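A note on the `pd.get_dummies(songs_target)[b'2']` line above: `loadarff` returns nominal attributes as byte strings, so the class column holds values like `b'1'` (Dislike) and `b'2'` (Like), and the one-hot columns are keyed by those byte strings. A minimal illustration (with a made-up toy target):

```python
import pandas as pd

# loadarff yields byte-string class labels, e.g. b'1' / b'2'
target = pd.Series([b'1', b'2', b'2', b'1'])
onehot = pd.get_dummies(target)

print(list(onehot.columns))                    # columns are the byte values [b'1', b'2']
print(onehot[b'2'].astype(bool).tolist())      # indicator for the 'Like' class
```

Selecting the `b'2'` column therefore gives a single boolean target vector suitable for binary classification.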
# You can use this code to search for text in either the artist or the title of the song and record its ID, so you can look it up in the actual prediction dataset later.
# If you can't seem to find an artist or song, it may not be in the dataset (which is not that exhaustive), or you may be writing it differently from the dataset's record,
# in which case you can try to spell it differently (such as 'michael' instead of 'Michael'). A trick you can use, because str.contains() treats its input like a regular
# expression, is to write just the middle part: instead of 'Billie', just write 'illie' to avoid upper/lowercase conflicts.
artist_songs = song_names[song_names['artist'].str.contains("Michael Jackson")]
named_songs = song_names[song_names['song_title'].str.contains("Hips Don")]
print(artist_songs)
print(named_songs)
artist song_title id
1827 Michael Jackson Billie Jean 1827
1828 Michael Jackson Beat It - Single Version 1828
1829 Michael Jackson Black or White - Single Version 1829
1830 Michael Jackson The Way You Make Me Feel 1830
1831 Michael Jackson Man In The Mirror 1831
1832 Michael Jackson P.Y.T. (Pretty Young Thing) 1832
1920 Michael Jackson Remember the Time 1920
1948 Michael Jackson Earth Song - Remastered Version 1948
artist song_title id
1927 Shakira Hips Don't Lie 1927
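As an aside, the upper/lowercase workaround described above can also be handled directly: pandas' `str.contains` accepts a `case=False` flag for case-insensitive matching, and `regex=False` to treat the pattern as a literal substring (useful when a title contains regex metacharacters like `.` or `(`). A small sketch on a made-up frame:

```python
import pandas as pd

# Hypothetical miniature version of the song_names frame
songs = pd.DataFrame({
    'artist': ['Michael Jackson', 'michael buble', 'Shakira'],
    'song_title': ['Billie Jean', 'Feeling Good', "Hips Don't Lie"],
})

# case=False removes the need for the 'illie' middle-substring trick;
# regex=False avoids accidental regex interpretation of the pattern
hits = songs[songs['artist'].str.contains('michael', case=False, regex=False)]
print(hits['artist'].tolist())  # ['Michael Jackson', 'michael buble']
```

The same flags work on `song_title`, so "Hips Don't Lie" can be matched literally without worrying about the apostrophe.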
# 1600 instances for training, which is approximately 80% of the 2017 instances
X_songs_train = X_songs.iloc[0:1600, :]
X_songs_test = X_songs.iloc[1600:, :]
y_songs_train = y_songs.iloc[0:1600]
y_songs_test = y_songs.iloc[1600:]
print(X_songs.iloc[1927])
print(X_songs_test.iloc[1927-1600])
acousticness             0.2840
danceability             0.7780
duration_ms         218093.0000
energy                   0.8240
instrumentalness         0.0000
key                     10.0000
liveness                 0.4050
loudness                -5.8920
mode                     0.0000
speechiness              0.0712
tempo                  100.0240
time_signature           4.0000
valence                  0.7670
Name: 1927, dtype: float64
acousticness             0.2840
danceability             0.7780
duration_ms         218093.0000
energy                   0.8240
instrumentalness         0.0000
key                     10.0000
liveness                 0.4050
loudness                -5.8920
mode                     0.0000
speechiness              0.0712
tempo                  100.0240
time_signature           4.0000
valence                  0.7670
Name: 1927, dtype: float64
# We fit a Decision Tree, and a Random Forest.
dtc_songs = tree.DecisionTreeClassifier(max_depth=5)
dtc_songs.fit(X_songs_train, y_songs_train)
rfc_songs = RandomForestClassifier(n_estimators=1000)
rfc_songs.fit(X_songs_train, y_songs_train)
models_songs = [dtc_songs, rfc_songs]
models_names_songs = ['TREE_songs', 'RF_songs']
# Cross-validation for the models just fit above
print("1) Cross-validation (over training data)")
scorings = ['accuracy', 'f1'] # For binary classification
nF= 5
for model, model_name in zip(models_songs, models_names_songs):
    cv_results = cross_validate(model, X_songs_train, y_songs_train, cv=nF,
                                scoring=scorings,
                                return_train_score=False)
    print(f'\n\t{model_name}:')
    print(f'\t\tCorrect Classification Rate [Average (St. Dev)] = {np.mean(cv_results["test_accuracy"]):.3f} ({np.std(cv_results["test_accuracy"]):.3f})')
    print(f'\t\tF-Score [Average (St. Dev)] = {np.mean(cv_results["test_f1"]):.3f} ({np.std(cv_results["test_f1"]):.3f})')
# Test with unknown instances
print("\n\n2) Test (with previously unseen data)")
models_acc_songs=[]
target_names = ['class 1', 'class 2']
for model, model_name in zip(models_songs, models_names_songs):
    sc = round(model.score(X_songs_test, y_songs_test), 3) # we round the accuracy to 3 decimals and append it to the list of model accuracies for the test set
    models_acc_songs.append(sc)
    # Mean accuracy of model.predict(X_songs_test) wrt y_songs_test
    print(f'\n\t{model_name}:')
    print(f'\t\tCorrect Classification Rate: {models_acc_songs[-1]:.3f}')
    y_pred = model.predict(X_songs_test)
    print(classification_report(y_songs_test, y_pred, target_names=target_names))
1) Cross-validation (over training data)
TREE_songs:
Correct Classification Rate [Average (St. Dev)] = 0.676 (0.041)
F-Score [Average (St. Dev)] = 0.761 (0.029)
RF_songs:
Correct Classification Rate [Average (St. Dev)] = 0.727 (0.054)
F-Score [Average (St. Dev)] = 0.811 (0.030)
2) Test (with previously unseen data)
TREE_songs:
Correct Classification Rate: 0.542
precision recall f1-score support
class 1 1.00 0.54 0.70 417
class 2 0.00 0.00 0.00 0
accuracy 0.54 417
macro avg 0.50 0.27 0.35 417
weighted avg 1.00 0.54 0.70 417
RF_songs:
Correct Classification Rate: 0.458
precision recall f1-score support
class 1 1.00 0.46 0.63 417
class 2 0.00 0.00 0.00 0
accuracy 0.46 417
macro avg 0.50 0.23 0.31 417
weighted avg 1.00 0.46 0.63 417
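Note the zero support for class 2 in both reports above: because the split simply takes the first 1600 rows, the tail of the (unshuffled) dataset contains a single class, which makes the test scores hard to trust. A sketch of how a shuffled, stratified split avoids this (toy data; `X_toy`/`y_toy` are hypothetical stand-ins for `X_songs`/`y_songs`):

```python
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd

# Ordered toy target mimicking the SONGS layout (all positives first):
# slicing the tail would leave a single-class test set.
X_toy = pd.DataFrame({'feat': np.arange(100)})
y_toy = pd.Series([1] * 60 + [0] * 40)

# stratify=y_toy shuffles and keeps the 60/40 class ratio in both splits
X_tr, X_te, y_tr, y_te = train_test_split(
    X_toy, y_toy, test_size=0.2, stratify=y_toy, random_state=42)

print(y_te.value_counts().to_dict())  # both classes present, ~60/40
```

With a stratified test set, metrics such as per-class recall and the macro average become meaningful for both classes.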
print(export_text(dtc_songs, feature_names=songs_feature_names[0:-1]))
dot_data_tree_songs = tree.export_graphviz(dtc_songs, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=songs_feature_names[0:-1])
graph_tree_songs = graphviz.Source(dot_data_tree_songs)
graph_tree_songs#.render(format='png') # Big trees are not very easy to examine on the notebook, so you may uncomment the render call to save them to a PNG file so that you can check out the graph elsewhere.
|--- energy <= 0.20
|   |--- tempo <= 182.99
|   |   |--- danceability <= 0.15
|   |   |   |--- class: True
|   |   |--- danceability > 0.15
|   |   |   |--- instrumentalness <= 0.00
|   |   |   |   |--- speechiness <= 0.05
|   |   |   |   |   |--- class: True
|   |   |   |   |--- speechiness > 0.05
|   |   |   |   |   |--- class: False
|   |   |   |--- instrumentalness > 0.00
|   |   |   |   |--- duration_ms <= 173942.50
|   |   |   |   |   |--- class: False
|   |   |   |   |--- duration_ms > 173942.50
|   |   |   |   |   |--- class: False
|   |--- tempo > 182.99
|   |   |--- class: True
|--- energy > 0.20
|   |--- loudness <= -5.63
|   |   |--- instrumentalness <= 0.00
|   |   |   |--- speechiness <= 0.08
|   |   |   |   |--- danceability <= 0.78
|   |   |   |   |   |--- class: False
|   |   |   |   |--- danceability > 0.78
|   |   |   |   |   |--- class: True
|   |   |   |--- speechiness > 0.08
|   |   |   |   |--- energy <= 0.34
|   |   |   |   |   |--- class: False
|   |   |   |   |--- energy > 0.34
|   |   |   |   |   |--- class: True
|   |   |--- instrumentalness > 0.00
|   |   |   |--- acousticness <= 0.86
|   |   |   |   |--- instrumentalness <= 0.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- instrumentalness > 0.00
|   |   |   |   |   |--- class: True
|   |   |   |--- acousticness > 0.86
|   |   |   |   |--- energy <= 0.39
|   |   |   |   |   |--- class: False
|   |   |   |   |--- energy > 0.39
|   |   |   |   |   |--- class: True
|   |--- loudness > -5.63
|   |   |--- duration_ms <= 259198.00
|   |   |   |--- duration_ms <= 165915.00
|   |   |   |   |--- valence <= 0.18
|   |   |   |   |   |--- class: False
|   |   |   |   |--- valence > 0.18
|   |   |   |   |   |--- class: True
|   |   |   |--- duration_ms > 165915.00
|   |   |   |   |--- instrumentalness <= 0.00
|   |   |   |   |   |--- class: False
|   |   |   |   |--- instrumentalness > 0.00
|   |   |   |   |   |--- class: True
|   |   |--- duration_ms > 259198.00
|   |   |   |--- acousticness <= 0.00
|   |   |   |   |--- class: False
|   |   |   |--- acousticness > 0.00
|   |   |   |   |--- instrumentalness <= 0.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- instrumentalness > 0.00
|   |   |   |   |   |--- class: True
explainerTreeSongs = shap.TreeExplainer(dtc_songs)
shap_values_DTC_songs = explainerTreeSongs(X_songs_test)
explainerRFSongs = shap.TreeExplainer(rfc_songs)
shap_values_RFC_songs = explainerRFSongs(X_songs_test)
explainers_songs = [explainerTreeSongs, explainerRFSongs]
shap.summary_plot(shap_values_DTC_songs[:,:,1])
shap.summary_plot(shap_values_RFC_songs[:,:,1])
Global explanation:
In analyzing the decision tree, we observe that energy is the root of the tree and drives the first split, so we can conclude that it is the most important feature for the model.
This suggests that the model distinguishes songs by how energetic they are.
After energy, tempo and loudness also play an important role in the model's decision logic, while features such as danceability and duration refine decisions at deeper levels of the tree; these contribute to the final classification but are less important than energy, tempo, and loudness.
The SHAP summary plot provides a complementary point of view. According to it, instrumentalness and loudness are the most important features. In particular, for loudness, higher values tend to push the prediction toward the 'Dislike' class: very loud songs are generally less preferred by the model.
Moreover, the SHAP plot shows some outliers related to energy, where very low energy values strongly contribute to a 'Dislike' prediction.
This shows that, even if energy is an important feature in the tree's decisions, very high or very low energy values can strongly affect the model's predictions.
When comparing the SHAP summary plots of the Decision Tree and the Random Forest, we can notice small differences. Since the Random Forest combines the predictions of many trees, it produces a smoother distribution of SHAP values, which reduces variability and the influence of individual splits.
In contrast, the Decision Tree assigns zero or near-zero importance to some features, since a single tree may never use them in its splits. This explains why certain features appear unimportant in the Decision Tree SHAP plot but still carry importance in the Random Forest.
In the end, both models identify the same factors influencing musical preferences: the Random Forest provides a more reliable and detailed explanation, while the Decision Tree provides simpler, easier-to-understand decision rules.
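The global-importance ranking discussed above can be made explicit: the SHAP summary and bar plots order features by the mean absolute SHAP value over the test set. A minimal sketch with a hypothetical SHAP matrix (the real values come from `shap_values_DTC_songs` / `shap_values_RFC_songs`):

```python
import numpy as np

# Hypothetical (instances x features) SHAP matrix and feature names,
# chosen to mirror the ranking observed in the notebook's plots
feature_names = np.array(['loudness', 'instrumentalness', 'energy', 'tempo'])
shap_matrix = np.array([[-0.30,  0.10,  0.05, 0.01],
                        [ 0.25, -0.40,  0.02, 0.00],
                        [-0.20,  0.35, -0.10, 0.02]])

# Mean absolute SHAP value per feature = global importance
importance = np.abs(shap_matrix).mean(axis=0)
order = np.argsort(importance)[::-1]  # descending
for name, imp in zip(feature_names[order], importance[order]):
    print(f'{name}: {imp:.3f}')
```

Applying the same three lines to the real SHAP matrices is a quick way to check, numerically, the visual ranking claimed from the summary plots.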
for instance in [1828-1600, 1927-1600]: # specific cases, feel free to alter them
    print('----------------------------------------------------------------------------------------------------------------------------')
    print(f'\nWORKING WITH INSTANCE {instance}:\nThe real output class is: {y_songs_test.iloc[instance]}')
    for model, model_name, explainer in zip(models_songs, models_names_songs, explainers_songs):
        print(f'The predicted output class by model {model_name} is: {model.predict(X_songs_test.iloc[instance : instance+1])[0]}')
        shap_values = explainer(X_songs_test.iloc[instance : instance+1])
        shap_class = shap_values[:, :, 1]
        print("Summary:")
        shap.summary_plot(shap_class)
        print("Bar:")
        shap.plots.bar(shap_class)
        print("Waterfall:")
        shap.plots.waterfall(shap_class[0])
        print("Partial dependence:")
        shap.plots.partial_dependence("loudness", model.predict, X_songs_test)
        print('----------------------------------------------------------------------------------------------------------------------------')
----------------------------------------------------------------------------------------------------------------------------

WORKING WITH INSTANCE 228:
The real output class is: False
The predicted output class by model TREE_songs is: False
Summary:
Bar:
Waterfall:
Partial dependence:
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model RF_songs is: True
Summary:
Bar:
Waterfall:
Partial dependence:
----------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------

WORKING WITH INSTANCE 327:
The real output class is: False
The predicted output class by model TREE_songs is: True
Summary:
Bar:
Waterfall:
Partial dependence:
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model RF_songs is: False
Summary:
Bar:
Waterfall:
Partial dependence:
----------------------------------------------------------------------------------------------------------------------------
Local explanation:
The selected songs for local analysis are Beat It by Michael Jackson (ID 1828 in the original dataset, corresponding to index 228 in the test set) and Hips Don't Lie by Shakira (ID 1927, corresponding to index 327 in the test set). Both songs were analyzed using the Decision Tree and the Random Forest in order to compare their local explanations.
For Beat It, the local SHAP explanations show that loudness and instrumentalness are the most influential features, and both contribute negatively to the probability of the song being liked. This suggests that, for this specific instance, high loudness levels and low instrumental content push the prediction toward the 'Dislike' class. On the other hand, danceability has a positive contribution, partially compensating for the negative effect of loudness and instrumentalness and supporting the 'Like' prediction. This combination of features reflects a trade-off between rhythmic appeal and production characteristics in the model's decision.
In the case of Hips Don't Lie, the same three features (loudness, danceability, and instrumentalness) also appear as the most important. However, their effects differ depending on the model. In the Decision Tree, loudness contributes positively to the prediction, indicating that higher loudness levels increase the likelihood of the song being liked for this instance. When analyzing the Random Forest, the importance of loudness and danceability decreases, while instrumentalness becomes the dominant factor by a large margin. This change reflects the more stable and averaged behavior of the Random Forest, which smooths the influence of individual features and relies more heavily on consistent patterns across many trees.
Overall, the local explanations highlight how the same features can have different impacts depending on both the specific song and the model used, emphasizing the value of local interpretability tools such as SHAP.
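The ordering the waterfall plots rely on can also be reproduced directly: for a single instance, sorting the SHAP vector by absolute value yields the features driving that particular prediction. A sketch with a hypothetical contribution vector (real values come from `shap_class[0]` in the loop above):

```python
import numpy as np

# Hypothetical single-instance SHAP vector for the 'Like' class,
# loosely mirroring the Hips Don't Lie discussion above
feature_names = np.array(['loudness', 'danceability', 'instrumentalness', 'tempo'])
contrib = np.array([-0.35, 0.20, -0.40, 0.03])

# Sort by magnitude, largest first, and report the push direction
order = np.argsort(np.abs(contrib))[::-1]
for name, value in zip(feature_names[order], contrib[order]):
    direction = 'towards Like' if value > 0 else 'towards Dislike'
    print(f'{name}: {value:+.2f} ({direction})')
```

Printing the top few entries like this is a convenient textual companion to the waterfall plot when comparing the two models' local explanations.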
# Building and visualizing the Pareto front with predictive accuracy versus number of leaf nodes, which can be understood as a surrogate for interpretability (remember
# how the largest tree was much harder to visually inspect)
limits = [min([dtc_songs.get_n_leaves()])*0.75, max([dtc_songs.get_n_leaves()])*1.5]
# Random Forest is a black box model and there is no easy, fair way to give it a nominal leaf node amount. We settle here for multiplying the largest tree x1.5 (penalty for opaqueness).
x_axis_songs = [dtc_songs.get_n_leaves(), limits[1]]
print("The data for the Pareto front is as follows:")
for name, accuracy, leaves in zip(models_names_songs, models_acc_songs, x_axis_songs):
    print(f'\t-For model {name}, accuracy is {accuracy} with a total of {leaves} leaf nodes.')
plt.figure(figsize=[15,10])
plot_pareto_front(x_axis_songs, models_acc_songs, models_names_songs, 'Accuracy (Classification Ratio)','Interpretability (Num of rules / leaves)', limits[0], limits[1])
The data for the Pareto front is as follows:
	-For model TREE_songs, accuracy is 0.542 with a total of 21 leaf nodes.
	-For model RF_songs, accuracy is 0.458 with a total of 31.5 leaf nodes.
# Building and visualizing the Pareto front with Accuracy versus Shap Length
x_axis_songs = [get_shap_explanation_length(explainer.shap_values(X_songs_test)[:,:,1]) for explainer in explainers_songs]
limits = [min(x_axis_songs)*0.99, max(x_axis_songs)*1.01]
print("The data for the Pareto front is as follows:")
for name, accuracy, shap_l in zip(models_names_songs, models_acc_songs, x_axis_songs):
    print(f'\t-For model {name}, accuracy is {accuracy} with a SHAP length of {shap_l}.')
plt.figure(figsize=[15,10])
plot_pareto_front(x_axis_songs,models_acc_songs,models_names_songs,'Accuracy (Classification Ratio)','Interpretability (Shap Length)',limits[0],limits[1])
The data for the Pareto front is as follows:
	-For model TREE_songs, accuracy is 0.542 with a SHAP length of 4879.
	-For model RF_songs, accuracy is 0.458 with a SHAP length of 4866.